import torch
import torch.nn as nn


class DECO_VID(nn.Module):
    """Transformer encoder-decoder that emits one logit per segment query.

    Per-frame feature vectors are linearly projected, summed with a learned
    positional embedding and passed through a transformer encoder. A fixed
    set of learned segment queries then attends to the encoded frames through
    a transformer decoder, and a linear head maps every decoded query to a
    single logit.

    Args:
        feature_dim: Model width used by the encoder, decoder and embeddings.
            Must be divisible by ``num_heads``. Note that the defaults
            (549 and 8) do not satisfy this, so at least one of the two has
            to be overridden.
        num_segments: Number of learned segment queries, i.e. logits per sample.
        num_frames: Length of the learned positional embedding; the maximum
            number of frames accepted by ``forward``.
        num_layers: Number of layers in both the encoder and the decoder.
        num_heads: Number of attention heads in every attention layer.
        hidden_dim: Width of the feed-forward sub-layers.
        dropout: Dropout probability inside the transformer layers.
        input_dim: Size of the incoming per-frame feature vectors. Defaults
            to 549, the value that was previously hard-coded.

    Raises:
        ValueError: If ``feature_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, feature_dim=549, num_segments=24, num_frames=16,
                 num_layers=3, num_heads=8, hidden_dim=256, dropout=0.2,
                 input_dim=549):
        super().__init__()

        # Multi-head attention splits the model width evenly across heads;
        # fail early with an actionable message instead of an opaque
        # assertion raised deep inside PyTorch.
        if feature_dim % num_heads != 0:
            raise ValueError(
                f"feature_dim ({feature_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        # NOTE: sub-modules are created in the same order as before so that
        # seeded initialisation and state_dict keys stay unchanged.

        # Project raw per-frame features to the model width.
        self.projection_layer = nn.Linear(input_dim, feature_dim)
        # Learned positional embedding, one vector per frame position.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_frames, feature_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=feature_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.temporal_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # One learned query vector per segment.
        self.segment_queries = nn.Parameter(torch.randn(num_segments, feature_dim))

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=feature_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Final classifier: a single (binary) logit per segment.
        self.classifier = nn.Linear(feature_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute per-segment logits for a batch of frame-feature sequences.

        Args:
            x: Tensor of shape ``(B, T, input_dim)`` holding one feature
                vector per frame, with ``T <= num_frames``.

        Returns:
            Tensor of shape ``(B, num_segments)`` containing raw logits.

        Raises:
            ValueError: If ``x`` is not 3-dimensional or ``T`` exceeds the
                length of the positional embedding.
        """
        B, T, _ = x.shape

        max_frames = self.pos_embedding.shape[1]
        if T > max_frames:
            raise ValueError(
                f"Input has {T} frames but the positional embedding only "
                f"covers {max_frames}"
            )

        # Temporal encoding. The positional table is sliced to the actual
        # sequence length: adding the full table would fail for shorter
        # clips, or silently broadcast a single frame to num_frames frames.
        x = self.projection_layer(x)
        x = x + self.pos_embedding[:, :T]
        temporal_x = self.temporal_encoder(x)  # (B, T, feature_dim)

        # Every sample uses the same learned queries; expand creates a
        # broadcast view rather than copying the parameter B times.
        queries = self.segment_queries.unsqueeze(0).expand(B, -1, -1)  # (B, num_segments, feature_dim)

        # Each segment query attends to the encoded frames.
        decoded = self.transformer_decoder(tgt=queries, memory=temporal_x)  # (B, num_segments, feature_dim)

        # One logit per segment.
        logits = self.classifier(decoded).squeeze(-1)  # (B, num_segments)

        return logits